import os
import numpy as np
import torch
from PIL import Image
from modules.tune.geowizard.models.geowizard_pipeline import DepthNormalEstimationPipeline

def depth_normal_estimation(
    input_tensor: torch.Tensor,
    checkpoint_path: str = 'lemonaddie/geowizard',
    domain: str = 'indoor',
    denoise_steps: int = 10,
    ensemble_size: int = 10,
    half_precision: bool = False,
    processing_res: int = 768,
    output_res_match_input: bool = True,
    color_map: str = 'Spectral',
):
    """
    Run the depth/normal estimation pipeline on a single image tensor.

    The tensor is converted to an 8-bit PIL image, passed through
    ``DepthNormalEstimationPipeline``, and the pipeline outputs are returned
    as tensors on the same device as ``input_tensor``.

    Args:
        input_tensor (torch.Tensor): Image tensor of shape (C, H, W). Values
            are scaled by 255 and cast to uint8, so they are expected to lie
            in [0, 1]; values outside that range are clamped.
        checkpoint_path (str): Identifier or path handed to
            ``DepthNormalEstimationPipeline.from_pretrained``.
        domain (str): Forwarded to the pipeline as ``domain``.
        denoise_steps (int): Forwarded to the pipeline as ``denoising_steps``.
        ensemble_size (int): Forwarded to the pipeline as ``ensemble_size``.
        half_precision (bool): Load the pipeline in float16 instead of float32.
        processing_res (int): Forwarded to the pipeline as ``processing_res``.
        output_res_match_input (bool): Forwarded to the pipeline as
            ``match_input_res``.
        color_map (str): Forwarded to the pipeline as ``color_map``.

    Returns:
        tuple: ``(depth, normal)`` tensors on ``input_tensor.device``.
            ``depth`` is ``pipe_out.depth_np`` repeated three times along a
            new leading dimension. ``normal`` is ``pipe_out.normal_colored``
            with its last axis moved to the front (presumably an (H, W, 3)
            colour image becoming (3, H, W) -- confirm against the pipeline).

    Raises:
        ValueError: If ``input_tensor`` is not 3-dimensional.
    """
    import logging

    # Fail early with a readable message instead of numpy's
    # "axes don't match array" after the checkpoint has already been loaded.
    if input_tensor.ndim != 3:
        raise ValueError(
            f"input_tensor must have shape (C, H, W); got {tuple(input_tensor.shape)}"
        )

    # Outputs are returned on whatever device the caller's tensor lives on.
    device = input_tensor.device

    # Determine precision
    dtype = torch.float16 if half_precision else torch.float32

    # Load pipeline
    # NOTE(review): the checkpoint is reloaded on every call; callers that
    # invoke this in a loop may want to cache the pipeline outside.
    pipe = DepthNormalEstimationPipeline.from_pretrained(checkpoint_path, torch_dtype=dtype)

    # Memory-efficient attention is an optional optimisation: keep going
    # without it, but do not swallow KeyboardInterrupt/SystemExit and leave a
    # trace of why it was skipped.
    try:
        pipe.enable_xformers_memory_efficient_attention()
    except Exception as exc:
        logging.getLogger(__name__).debug(
            "xformers memory-efficient attention not enabled: %s", exc
        )

    pipe = pipe.to(device)

    # Build the 8-bit image on the CPU. detach() allows tensors that require
    # grad; clamp() makes out-of-range values saturate instead of wrapping
    # around in the uint8 cast.
    pixels = input_tensor.detach().float().cpu().clamp(0.0, 1.0)
    input_image = Image.fromarray((pixels.numpy().transpose(1, 2, 0) * 255).astype(np.uint8))

    # Perform inference
    pipe_out = pipe(
        input_image,
        denoising_steps=denoise_steps,
        ensemble_size=ensemble_size,
        processing_res=processing_res,
        match_input_res=output_res_match_input,
        domain=domain,
        color_map=color_map,
        show_progress_bar=False,
    )

    # Replicate the depth map into three identical channels.
    depth_pred = torch.tensor(pipe_out.depth_np, device=device)
    depth_pred_tensor = torch.stack([depth_pred] * 3, dim=0)

    # np.asarray is a no-op for an ndarray and also accepts a PIL image, so
    # either representation of the coloured normal map converts cleanly.
    normal_pred_tensor = torch.tensor(np.asarray(pipe_out.normal_colored)).permute(2, 0, 1).to(device)

    return depth_pred_tensor, normal_pred_tensor

